import argparse
import json
import logging
from multiprocessing import Pool
import pandas as pd
import os
import tqdm
import urllib.request
import zipfile

# Configure root logging once at import time; everything this script logs
# goes through the dedicated "coco" logger below.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("coco")


def get_args():
    """Parse command-line options for the COCO-2014 dataset downloader.

    Returns:
        argparse.Namespace: Parsed options controlling where the dataset is
        stored, how many images/captions to keep, download parallelism,
        the shuffling seed, and optional latents file locations.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-dir", default="./coco-2014", help="Dataset download location"
    )
    parser.add_argument(
        "--tsv-path", default=None, help="Precomputed tsv file location"
    )
    parser.add_argument(
        "--max-images",
        default=5000,
        type=int,
        # Fixed typo in help text ("Maximun" -> "Maximum").
        help="Maximum number of images to download",
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        help="Number of processes to download images",
    )
    parser.add_argument(
        "--allow-duplicate-images",
        action="store_true",
        # Fixed typo in help text ("mulple" -> "multiple").
        help="Allow multiple captions per image",
    )
    parser.add_argument(
        "--latents-path-torch",
        default="latents.pt",
        type=str,
        help="Path to pytorch latents",
    )
    parser.add_argument(
        "--latents-path-numpy",
        default="latents.npy",
        type=str,
        help="Path to numpy latents",
    )
    parser.add_argument(
        "--seed", type=int, default=2023, help="Seed to choose the dataset"
    )
    parser.add_argument(
        "--keep-raw",
        action="store_true",
        help="Keep raw folder",
    )
    parser.add_argument(
        "--download-images", action="store_true", help="Download the calibration set"
    )

    args = parser.parse_args()
    return args


def download_img(args):
    """Download a single image, skipping it if already present locally.

    Args:
        args: Tuple of ``(img_url, target_folder, file_name)`` — packed as a
            single tuple so the function can be mapped over a worker pool
            with ``Pool.imap_unordered``.

    The destination path is built with ``os.path.join`` instead of raw
    string concatenation, so a ``target_folder`` without a trailing
    separator no longer produces a wrong path.
    """
    img_url, target_folder, file_name = args
    dest = os.path.join(target_folder, file_name)
    if os.path.exists(dest):
        log.warning(f"Image {file_name} found locally, skipping download")
    else:
        urllib.request.urlretrieve(img_url, dest)


if __name__ == "__main__":
    args = get_args()
    dataset_dir = os.path.abspath(args.dataset_dir)
    # Locate the captions dataframe, trying sources in priority order:
    # (1) a tsv already cached inside dataset_dir, (2) a tsv one directory
    # above it, (3) a user-supplied --tsv-path, (4) build one from the
    # official COCO 2014 validation annotations.
    # Check if the annotation dataframe is there
    if os.path.exists(f"{dataset_dir}/captions/captions_source.tsv"):
        df_annotations = pd.read_csv(
            f"{dataset_dir}/captions/captions_source.tsv", sep="\t"
        )
        # Keep only the first --max-images rows.
        df_annotations = df_annotations.iloc[: args.max_images]
    elif os.path.exists(f"{dataset_dir}/../captions_source.tsv"):
        # Copy the tsv found next to dataset_dir into the captions folder.
        os.makedirs(f"{dataset_dir}/captions/", exist_ok=True)
        os.system(
            f"cp {dataset_dir}/../captions_source.tsv {dataset_dir}/captions/")
        df_annotations = pd.read_csv(
            f"{dataset_dir}/captions/captions_source.tsv", sep="\t"
        )
        df_annotations = df_annotations.iloc[: args.max_images]
    elif args.tsv_path is not None and os.path.exists(f"{args.tsv_path}"):
        # Copy a user-provided tsv into the captions folder and load it.
        file_name = args.tsv_path.split("/")[-1]
        os.makedirs(f"{dataset_dir}/captions/", exist_ok=True)
        os.system(f"cp {args.tsv_path} {dataset_dir}/captions/")
        df_annotations = pd.read_csv(
            f"{dataset_dir}/captions/{file_name}", sep="\t")
        df_annotations = df_annotations.iloc[: args.max_images]
    else:
        # No precomputed tsv available: download and process the official
        # COCO 2014 annotations.
        # Check if raw annotations file already exist
        if not os.path.exists(
                f"{dataset_dir}/raw/annotations/captions_val2014.json"):
            # Download annotations
            os.makedirs(f"{dataset_dir}/raw/", exist_ok=True)
            os.makedirs(f"{dataset_dir}/download_aux/", exist_ok=True)
            # NOTE(review): shells out to wget; assumes it is installed and
            # ignores its exit status — a failed download only surfaces later
            # as a missing/corrupt zip file.
            os.system(
                f"cd {dataset_dir}/download_aux/ && \
                    wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip --show-progress"
            )

            # Unzip file
            with zipfile.ZipFile(
                f"{dataset_dir}/download_aux/annotations_trainval2014.zip", "r"
            ) as zip_ref:
                zip_ref.extractall(f"{dataset_dir}/raw/")

        # Move captions to target folder
        os.makedirs(f"{dataset_dir}/captions/", exist_ok=True)
        os.system(
            f"mv {dataset_dir}/raw/annotations/captions_val2014.json {dataset_dir}/captions/"
        )
        # Clean up the intermediate download folders unless asked otherwise.
        if not args.keep_raw:
            os.system(f"rm -rf {dataset_dir}/raw")
        os.system(f"rm -rf {dataset_dir}/download_aux")
        # Convert to dataframe format and extract the relevant fields
        with open(f"{dataset_dir}/captions/captions_val2014.json") as f:
            captions = json.load(f)
            annotations = captions["annotations"]
            images = captions["images"]
        df_annotations = pd.DataFrame(annotations)
        df_images = pd.DataFrame(images)
        # By default keep only one caption per image.
        if not args.allow_duplicate_images:
            df_annotations = df_annotations.drop_duplicates(
                subset=["image_id"], keep="first"
            )
        # Sort, shuffle and choose the final dataset
        # Sorting by caption id before the seeded shuffle makes the
        # selection reproducible regardless of the json ordering.
        df_annotations = df_annotations.sort_values(by=["id"])
        df_annotations = df_annotations.sample(
            frac=1, random_state=args.seed
        ).reset_index(drop=True)
        df_annotations = df_annotations.iloc[: args.max_images]
        # Strip embedded newlines and surrounding whitespace so captions
        # survive the tab-separated output format.
        df_annotations["caption"] = df_annotations["caption"].apply(
            lambda x: x.replace("\n", "").strip()
        )
        # Join image metadata (file_name, coco_url, height, width) onto the
        # captions; "id" is the caption id and "image_id" links to images.
        df_annotations = (
            df_annotations.merge(
                df_images, how="inner", left_on="image_id", right_on="id"
            )
            .drop(["id_y"], axis=1)
            .rename(columns={"id_x": "id"})
            .sort_values(by=["id"])
            .reset_index(drop=True)
        )
    # Download images
    if args.download_images:
        os.makedirs(f"{dataset_dir}/validation/data/", exist_ok=True)
        # One (url, folder, name) task per caption row, fanned out to a
        # multiprocessing pool below.
        tasks = [
            (row["coco_url"],
             f"{dataset_dir}/validation/data/",
             row["file_name"])
            for i, row in df_annotations.iterrows()
        ]
        # NOTE(review): the pool is never close()d or join()ed; consider
        # wrapping it in a context manager.
        pool = Pool(processes=args.num_workers)
        # Drain the unordered result iterator through tqdm purely to render
        # a progress bar; the results themselves are discarded.
        [
            _
            for _ in tqdm.tqdm(
                pool.imap_unordered(download_img, tasks), total=len(tasks)
            )
        ]
    # Finalize annotations
    # Persist the selected columns as the canonical captions.tsv.
    df_annotations[
        ["id", "image_id", "caption", "height", "width", "file_name", "coco_url"]
    ].to_csv(f"{dataset_dir}/captions/captions.tsv", sep="\t", index=False)

    # Copy precomputed latents next to the dataset when the files exist.
    if os.path.exists(args.latents_path_torch):
        os.makedirs(f"{dataset_dir}/latents/", exist_ok=True)
        os.system(f"cp {args.latents_path_torch} {dataset_dir}/latents/")

    if os.path.exists(args.latents_path_numpy):
        os.makedirs(f"{dataset_dir}/latents/", exist_ok=True)
        os.system(f"cp {args.latents_path_numpy} {dataset_dir}/latents/")
